// Copyright 2022 The Google Research Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <fstream>
#include <iostream>

#include "test/common/GoldModel.h"
#include "test/common/OperationMapper.h"
#include "test/common/Utils.h"
#include "test/common/operation.proto.h"

// Forward declaration of the accelerator entry point; the definition is not in
// this file. It is invoked from sc_main with the mapped tensor parameters, the
// vector-unit parameters and the simulated accelerator memory.
// NOTE(review): presumably the results are written back into `mainMemory` at
// the configured output offset -- confirm against the definition.
void run_op(const Params params, const VectorParams vectorParams,
            INPUT_DATATYPE *mainMemory);


// Allocates the simulated accelerator memory and the host-side buffers used by
// the gold model, and fills the inputs with a deterministic ramp in which each
// element's value equals its linear index.
//
// Tensor operations (matrix multiplication or convolution):
//   - matrixA holds the X*Y*C input, copied to acceleratorMemory at
//     params.INPUT_OFFSET (set to 0).
//   - matrixB holds the FX*FY*C*K weights, copied to acceleratorMemory at
//     params.WEIGHT_OFFSET (placed directly after the input).
//   - matrixC is an uninitialised X*Y*K output buffer; params.OUTPUT_OFFSET is
//     placed directly after the weights.
// Vector operations (softmax, otherwise layer norm):
//   - matrixA holds the X*Y input, copied to acceleratorMemory at
//     vectorParams.VECTOR_OFFSET (set to 0).
//   - matrixB is set to nullptr and matrixBSize to 0.
//   - matrixC is an uninitialised X*Y output buffer;
//     vectorParams.OUTPUT_OFFSET aliases the input region.
//
// All four pointers are (re)assigned; every non-null buffer is allocated with
// new[] and must be released by the caller with delete[].
// NOTE(review): the accelerator memory has a fixed size of 4Mi elements and no
// check is made that the operands fit -- confirm the operation sizes in use.
void InitMemories(
    const third_party::dnn_accelerator_hls::test::common::Operation &operation,
    Params &params, VectorParams &vectorParams,
    INPUT_DATATYPE *&acceleratorMemory, INPUT_DATATYPE *&matrixA,
    int &matrixASize, INPUT_DATATYPE *&matrixB, int &matrixBSize,
    INPUT_DATATYPE *&matrixC, int &matrixCSize) {
  acceleratorMemory = new INPUT_DATATYPE[4 * 1024 * 1024];

  if (operation.has_tensor_operation()) {
    int X, Y, C, K, FX, FY;
    if (operation.tensor_operation().has_matrix_multiplication()) {
      // Bind by reference: there is no need to copy the protobuf message.
      const auto &matMulOp =
          operation.tensor_operation().matrix_multiplication();
      // A matrix multiplication is expressed with a single output row (Y = 1)
      // and a 1x1 filter.
      X = matMulOp.ay();
      Y = 1;
      C = matMulOp.ax();
      K = matMulOp.bx();
      FX = 1;
      FY = 1;
    } else {
      const auto &convOp = operation.tensor_operation().convolution();
      X = convOp.ox();
      Y = convOp.oy();
      C = convOp.ic();
      K = convOp.oc();
      FX = convOp.fx();
      FY = convOp.fy();
    }

    // Input: indexed as [y][x][c], row-major.
    matrixASize = X * Y * C;
    matrixA = new INPUT_DATATYPE[matrixASize];
    params.INPUT_OFFSET = 0;

    for (int y = 0; y < Y; y++) {
      for (int x = 0; x < X; x++) {
        for (int c = 0; c < C; c++) {
          // The stored value is the element's own linear address.
          const int index = y * X * C + x * C + c;

          acceleratorMemory[params.INPUT_OFFSET + index] = index;
          matrixA[index] = index;
        }
      }
    }

    // Weights: indexed as [fy][fx][c][k], row-major, placed right after the
    // input in accelerator memory.
    matrixBSize = FX * FY * C * K;
    matrixB = new INPUT_DATATYPE[matrixBSize];
    params.WEIGHT_OFFSET = matrixASize;

    for (int fy = 0; fy < FY; fy++) {
      for (int fx = 0; fx < FX; fx++) {
        for (int c = 0; c < C; c++) {
          for (int k = 0; k < K; k++) {
            const int index = fy * FX * C * K + fx * C * K + c * K + k;

            acceleratorMemory[params.WEIGHT_OFFSET + index] = index;
            matrixB[index] = index;
          }
        }
      }
    }

    // Output: only allocated here; its region follows the weights.
    matrixCSize = X * Y * K;
    matrixC = new OUTPUT_DATATYPE[matrixCSize];
    params.OUTPUT_OFFSET = params.WEIGHT_OFFSET + matrixBSize;
  } else {
    int X, Y;
    if (operation.has_softmax()) {
      Y = operation.softmax().tensor_height();
      X = operation.softmax().tensor_width();
    } else {
      Y = operation.layer_norm().tensor_height();
      X = operation.layer_norm().tensor_width();
    }

    matrixASize = X * Y;
    matrixA = new INPUT_DATATYPE[matrixASize];
    vectorParams.VECTOR_OFFSET = 0;
    for (int y = 0; y < Y; y++) {
      for (int x = 0; x < X; x++) {
        const int index = y * X + x;

        // Write at the vector unit's offset; params.INPUT_OFFSET is never set
        // on this path, so it must not be used to place the input.
        acceleratorMemory[vectorParams.VECTOR_OFFSET + index] = index;
        matrixA[index] = index;
      }
    }

    // Vector operations have no second operand. Null the pointer so that the
    // caller's delete[] is safe regardless of how it was initialised.
    matrixB = nullptr;
    matrixBSize = 0;

    matrixCSize = X * Y;
    matrixC = new OUTPUT_DATATYPE[matrixCSize];
    // The result overwrites the input region in accelerator memory.
    vectorParams.OUTPUT_OFFSET = vectorParams.VECTOR_OFFSET;
  }
}

int sc_main(int argc, char *argv[]) {
  third_party::dnn_accelerator_hls::test::common::Operation operation;

  std::string prefix("--proto=");
  for (int i = 1; i < argc; i++) {
    std::string arg(argv[i]);
    if (arg.rfind(prefix, 0) == 0) {
      std::string path_to_proto = arg.substr(prefix.size());
      std::fstream inputFile(argv[1], ios::in | ios::binary);
      operation.ParseFromIstream(&inputFile);
      break;
    }
  }


  Params params;
  VectorParams vectorParams;
  MapOperationToAccelerator(operation, params, vectorParams);

  INPUT_DATATYPE *acceleratorMemory = nullptr;
  INPUT_DATATYPE *matrixA = nullptr;
  INPUT_DATATYPE *matrixB = nullptr;
  INPUT_DATATYPE *matrixC = nullptr;

  int matrixASize, matrixBSize, matrixCSize;

  InitMemories(operation, params, vectorParams, acceleratorMemory, matrixA,
               matrixASize, matrixB, matrixBSize, matrixC, matrixCSize);

  run_op(params, vectorParams, acceleratorMemory);
  run_gold_op(operation, matrixA, matrixB, matrixC);

  compare_arrays(&acceleratorMemory[params.OUTPUT_OFFSET], matrixC,
                 matrixCSize);

  delete[] matrixA;
  delete[] matrixB;
  delete[] matrixC;
  delete[] acceleratorMemory;

  return 1;
}
